Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Tune] Save and restore stateful callbacks as part of experiment checkpoint #31957

Merged
merged 15 commits into from
Feb 1, 2023

Conversation

justinvyu
Copy link
Contributor

@justinvyu justinvyu commented Jan 26, 2023

Why are these changes needed?

This PR allows Callbacks to implement the get_state and set_state methods for save/restore. During experiment checkpoints, Tune will save callback states, if any callbacks implement get_state. During experiment restore, Tune will load the latest callback states.

Context

Callback state is not currently saved as part of experiment state. For user defined stateful callbacks, this means that their callbacks will not continue from where they left off on experiment restoration.

What's the benefit of doing this instead of just pickling the callbacks along with the TrialRunner state? We shift the responsibility of save/load to the user instead of doing everything in Tune. For example, this allows users to re-create non-serializable objects (ex: threading objects), if used in the callback.

class StatefulCallback(Callback):
    def __init__(self):
        self.obj_ref = Model.options(name="shared_model").remote()
        self.nonserializable = threading.Lock()
        self.state = 0

    def get_state(self):
        return {"state": self.state, "model_state": ray.get(self.obj_ref).get_state()}

    def set_state(self, state):
        self.state = state["state"]
        # Re-create some named actor that's accessible by your trainable
        restored_model = Model.options(name="shared_model).remote()
        ray.get(restored_model.set_state.remote(state["model_state"]))
        self.obj_ref = restored_model
        self.nonserializable = threading.Lock()

@ray.remote
class Model:
    def get_state(self):
        pass
    def set_state(self, state):
        pass

TODO

  • Add example stateful callback + get/setstate to the docs. Here and here.
  • In a future PR, change Searcher/SearchAlgorithm to use this common Restorable interface.

Related issue number

Checks

  • I've signed off every commit(by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

Signed-off-by: Justin Yu <justinvyu@berkeley.edu>
Signed-off-by: Justin Yu <justinvyu@berkeley.edu>
Signed-off-by: Justin Yu <justinvyu@berkeley.edu>
Signed-off-by: Justin Yu <justinvyu@berkeley.edu>
Signed-off-by: Justin Yu <justinvyu@berkeley.edu>
Signed-off-by: Justin Yu <justinvyu@berkeley.edu>
Signed-off-by: Justin Yu <justinvyu@berkeley.edu>
Signed-off-by: Justin Yu <justinvyu@berkeley.edu>
Signed-off-by: Justin Yu <justinvyu@berkeley.edu>
Signed-off-by: Justin Yu <justinvyu@berkeley.edu>
Copy link
Contributor

@krfricke krfricke left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The main question I have is why we need a new Restorable interface, when most of the functionality is provided by python already.

I.e. there is obj.__getstate__() and obj.__setstate__(), and it's automatically used for pickling and unpickling. Shouldn't users just implement those?

For the callback list or other locations where we pickle callbacks, we can catch serialization errors and provide an actionable error message.

It feels like we're trying to solve a problem that's already solved. If we went with Restorable, would the trial runner, searcher, etc also eventually inherit this interface?

@justinvyu
Copy link
Contributor Author

justinvyu commented Jan 30, 2023

Just using the python __getstate__ and __setstate__ was something that I was initially considering, but I went for the get_state/set_state route because other things like searchers and schedulers have this interface. Then, I was thinking about having this Restorable interface be common across all of them including callbacks.

But, I agree that that this is not needed and introduces an unnecessary new interface. We can move to just pickling these objects directly and implementing __getstate__ and __setstate__ for all of them rather than having a separate save method that just pickles the entire object __dict__ anyways. Will update this PR.

Update: We will go with the current get_state and set_state, but not introduce a general Restorable interface, for the following reasons:

  1. Callbacks should be stateless by default, and we don't want to pickle them all unnecessarily (the default __getstate__ would save everything).
  2. We technically should be loading state on experiment restoration rather than replacing the entire callback object. We first load the pickled callbacks (as they were initialized and passed into the Tuner), and then we load the state.
  3. We shouldn't introduce a Restorable interface, because it's unclear how this will be used in the future, and more discussion is needed before introducing something that multiple components will implement. It's also confusing to users/developers when to use this interface.

@justinvyu justinvyu requested a review from krfricke January 31, 2023 23:28
try:
callback_state = callback.get_state()
any_stateful_callbacks = True
except NotImplementedError:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also check for AttributeError?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have replaced the get_state with a default of returning None (stateless). No need to check for this error anymore. I think an attribute error should get raised immediately instead of caught.

python/ray/tune/callback.py Show resolved Hide resolved
Copy link
Contributor

@krfricke krfricke left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome!

Signed-off-by: Justin Yu <justinvyu@berkeley.edu>
Signed-off-by: Justin Yu <justinvyu@berkeley.edu>
@justinvyu justinvyu requested a review from Yard1 February 1, 2023 01:33
@gjoliver gjoliver merged commit 890e034 into ray-project:master Feb 1, 2023
edoakes pushed a commit to edoakes/ray that referenced this pull request Mar 22, 2023
…kpoint (ray-project#31957)

Signed-off-by: Justin Yu <justinvyu@berkeley.edu>
Signed-off-by: Edward Oakes <ed.nmi.oakes@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants